#!/usr/bin/env python3
import argparse
import glob
import json
import os
import re
from typing import List

import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, mean_absolute_percentage_error

# --------------------------------------------------
# Helpers
# --------------------------------------------------

def load_ctr_map(path: str) -> dict:
    """Load the CTR lookup table stored as a JSON object at *path*.

    Every key is converted with ``str`` and every value with ``float``, so
    the returned dictionary always has the shape ``{str: float}``.

    Raises whatever ``open``, ``json.load`` or ``float`` raise for an
    unreadable file, invalid JSON, or a non-numeric value.
    """
    with open(path, "r", encoding="utf-8") as handle:
        raw_mapping = json.load(handle)

    normalised = {}
    for raw_key, raw_value in raw_mapping.items():
        normalised[str(raw_key)] = float(raw_value)
    return normalised


def collect_results(result_files: List[str], ctr_map: dict):
    """Gather ground-truth / prediction pairs from JSON result files.

    Each file is expected to hold a JSON list of record objects.  A record
    contributes one row when

    * it has a ``ground_truth`` value and a prediction (``avg_predicted_score``
      if that key is present, otherwise ``predicted_score``),
    * both values can be parsed as numbers and rounded to an integer, and
    * both rounded integers have an entry in *ctr_map* (looked up as ``"<int>"``
      first, then as ``"<int>.0"``).

    Collection is best-effort: files that cannot be read or parsed, files whose
    top-level JSON value is not a list, and records that are not objects are
    skipped silently instead of aborting the run.

    Args:
        result_files: Paths of the JSON result files to read, in order.
        ctr_map: Mapping from percentile label (string key) to CTR value.

    Returns:
        A ``pandas.DataFrame`` with one row per accepted record and columns
        ``ctr_y`` / ``socia_pred`` (mapped CTR values for ground truth and
        prediction) and ``pct_y`` / ``pct_pred`` (the rounded integer labels).
    """

    def resolve_key(k_int: int):
        # Prefer the plain integer key; fall back to the "<int>.0" form used
        # by some mapping files.  Returns None when neither form is present.
        k_plain = str(k_int)
        if k_plain in ctr_map:
            return k_plain
        k_dot = f"{k_int}.0"
        if k_dot in ctr_map:
            return k_dot
        return None

    gt_vals, pred_vals = [], []  # mapped CTR values
    gt_pct_vals, pred_pct_vals = [], []  # rounded integer labels

    for fp in result_files:
        try:
            # Context manager so the handle is closed even if parsing fails.
            with open(fp, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Unreadable file, bad encoding or invalid JSON: skip this file.
            continue

        if not isinstance(data, list):
            # e.g. a mapping/config JSON that happens to live next to the
            # result files; it contains no records to evaluate.
            continue

        for rec in data:
            if not isinstance(rec, dict):
                continue

            gt_key = rec.get("ground_truth")
            pred_key = rec.get("avg_predicted_score", rec.get("predicted_score"))
            if pred_key is None or gt_key is None:
                continue

            try:
                # Round to the nearest integer label.  NaN raises ValueError
                # and infinity raises OverflowError inside round().
                gt_int = int(round(float(gt_key)))
                pr_int = int(round(float(pred_key)))
            except (TypeError, ValueError, OverflowError):
                # Skip records that cannot be parsed as numeric
                continue

            gt_key_resolved = resolve_key(gt_int)
            pr_key_resolved = resolve_key(pr_int)
            if gt_key_resolved is None or pr_key_resolved is None:
                # Mapping not found for either ground truth or prediction
                continue

            # Store mapped CTR values
            gt_vals.append(ctr_map[gt_key_resolved])
            pred_vals.append(ctr_map[pr_key_resolved])

            # Also store the rounded integer labels
            gt_pct_vals.append(gt_int)
            pred_pct_vals.append(pr_int)

    return pd.DataFrame({
        "ctr_y": gt_vals,          # continuous CTR values
        "socia_pred": pred_vals,   # continuous CTR predictions
        "pct_y": gt_pct_vals,      # integer percentiles (ground truth)
        "pct_pred": pred_pct_vals  # integer percentiles (prediction)
    })


def compute_bucket_accuracy(df: pd.DataFrame):
    """Return ``(acc3, acc10)``: tertile and decile bucket accuracy.

    Bucket boundaries are the 1/3, 2/3 quantiles (3-way) and the
    0.1 ... 0.9 quantiles (10-way) of the *ground-truth* series.  Ground truth
    and prediction are both assigned to buckets using those same boundaries,
    and the accuracy is the share of rows whose two bucket indices agree.

    The integer percentile columns ``pct_y`` / ``pct_pred`` are used when both
    are present; otherwise the CTR columns ``ctr_y`` / ``socia_pred`` are used
    (old behaviour).
    """
    if "pct_y" in df and "pct_pred" in df:
        truth_col, guess_col = "pct_y", "pct_pred"
    else:
        truth_col, guess_col = "ctr_y", "socia_pred"
    truth = df[truth_col]
    guess = df[guess_col]

    def bucket_accuracy(cut_points):
        # Open-ended outer edges so every finite value lands in some bucket.
        edges = [-np.inf, *cut_points, np.inf]
        truth_bucket = np.digitize(truth, bins=edges)
        guess_bucket = np.digitize(guess, bins=edges)
        return accuracy_score(truth_bucket, guess_bucket)

    # 3-bucket
    tertile_cuts = truth.quantile([1 / 3, 2 / 3]).values
    acc3 = bucket_accuracy(tertile_cuts)

    # 10-bucket
    decile_cuts = truth.quantile(np.arange(0.1, 1.0, 0.1)).values.tolist()
    acc10 = bucket_accuracy(decile_cuts)

    return acc3, acc10


def main():
    """Command-line entry point: compute and save metrics for result files.

    Recursively finds every ``*.json`` file under ``--results_dir``, turns
    each into a frame via ``collect_results`` using the ``--ctr_map`` lookup,
    and reports one metric line per file that yielded records, followed by an
    ``OVERALL`` line over all records.  The report is printed to stdout and
    written to ``--out_file`` (default ``<results_dir>/metrics.txt``); a
    failure to write the file is reported as a warning only.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--results_dir", required=True, help="Directory containing slice/merged JSON result files")
    ap.add_argument("--ctr_map", required=True, help="JSON file mapping ID -> real CTR value")
    ap.add_argument("--out_file", default=None, help="Path to save the metric summary (default: <results_dir>/metrics.txt)")
    args = ap.parse_args()

    ctr_map = load_ctr_map(args.ctr_map)

    result_files = sorted(glob.glob(os.path.join(args.results_dir, "**", "*.json"), recursive=True))
    if not result_files:
        print("[ERROR] No JSON result files found in", args.results_dir)
        return

    def metric_line(name: str, df: pd.DataFrame):
        """Format one tab-separated summary line for the records in *df*."""
        y_t = df["ctr_y"]
        y_p = df["socia_pred"]
        rm = np.sqrt(mean_squared_error(y_t, y_p))
        r2v = r2_score(y_t, y_p)

        # Mean Absolute Percentage Error (sklearn) – exclude rows where y_t == 0
        non_zero_mask = y_t != 0
        if non_zero_mask.any():
            mape_val = mean_absolute_percentage_error(y_t[non_zero_mask], y_p[non_zero_mask]) * 100
        else:
            mape_val = float('nan')

        # Element-wise percentage errors for the p10 / p20 / p30 bands.
        # Rows with y_t == 0 produce inf or NaN here and are dropped below.
        pe = 100 * np.abs(y_t - y_p) / y_t
        df_pe = pe.replace([np.inf, -np.inf], np.nan).dropna().to_frame("ctr")
        p10v = (df_pe["ctr"] < 10).mean()
        p20v = (df_pe["ctr"] < 20).mean()
        p30v = (df_pe["ctr"] < 30).mean()
        a3, a10 = compute_bucket_accuracy(df)
        return f"{name}\tN:{len(df)}\tRMSE:{rm:.4f}\tMAPE:{mape_val:.2f}%\tR2:{r2v:.4f}\tp10:{p10v:.3f}\tp20:{p20v:.3f}\tp30:{p30v:.3f}\tAcc3:{a3:.3f}\tAcc10:{a10:.3f}"

    # Read every file exactly once: the per-file frames feed both the
    # per-file lines and, concatenated in file order, the overall line.
    out_lines: List[str] = []
    frames: List[pd.DataFrame] = []
    for fp in result_files:
        df_f = collect_results([fp], ctr_map)
        if df_f.empty:
            continue
        frames.append(df_f)
        out_lines.append(metric_line(os.path.basename(fp), df_f))

    if not frames:
        print("[ERROR] No valid records after mapping – check files/mapping.")
        return

    df_all = pd.concat(frames, ignore_index=True)

    # Overall metrics line
    out_lines.append("-"*80)
    out_lines.append(metric_line("OVERALL", df_all))

    print("\n".join(out_lines))

    out_path = args.out_file or os.path.join(args.results_dir, "metrics.txt")
    try:
        with open(out_path, "w", encoding="utf-8") as f_out:
            f_out.write("\n".join(out_lines))
        print(f"\nMetrics written to {out_path}")
    except Exception as e:
        # Best-effort: the report was already printed, so only warn.
        print("[WARNING] Could not write metrics file:", e)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()